fix(distillation): reverse-KL server path NaN on variable completion length #2
When ``use_teacher_server=True`` with ``beta > 0`` and ``bs * grad_accum > 1``,
the reverse-KL server path leaked NaN into the backward pass whenever
per-sample completion lengths differed within a batch.
Root cause
----------
``_get_teacher_token_logprobs_from_server`` fills the rectangular (B, T)
output tensor with the TRL house sentinel ``float("-inf")`` at intra-batch
padding positions (the tail of shorter samples). The forward-KL server path
(``_compute_server_forward_kl_loss``) neutralises this via
``torch.where(teacher > -inf, ..., -inf)`` plus a support mask threaded
through ``_add_tail_bucket``; the reverse-KL server path
(``_compute_server_sparse_top_1_divergence_loss``) did not. Both paths
landed in the same commit (huggingface#5407) -- an oversight, not deliberate
asymmetry.
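For illustration, a minimal sketch of the ragged-to-rectangular reassembly
described above (invented values; not the actual getter):

```python
import torch

# Two completions of lengths 3 and 1, reassembled into a rectangular (B, T)
# tensor with the -inf sentinel at the tail of the shorter sample.
rows = [torch.tensor([-0.1, -0.9, -0.4]), torch.tensor([-0.2])]
out = torch.full((len(rows), max(r.numel() for r in rows)), float("-inf"))
for i, r in enumerate(rows):
    out[i, : r.numel()] = r
# out: [[-0.1, -0.9, -0.4],
#       [-0.2, -inf, -inf]]
```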
Unmasked, the -inf sentinel produces a teacher distribution [-inf, 0]
after ``_add_tail_bucket`` and +inf in ``_jsd_divergence``'s forward pass
(clamped to ``finfo.max`` by ``nan_to_num``), but NaN in the backward
pass: ``nan_to_num`` clamps only the forward value, while the chain rule
still multiplies through the pre-clamp +inf (0 * inf = NaN), so it
surfaces as a NaN gradient.
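A minimal standalone repro of that backward-pass leak (a hypothetical toy
term, not the trainer's exact divergence):

```python
import torch

student = torch.zeros(2, requires_grad=True)    # student logprobs
teacher = torch.tensor([0.0, float("-inf")])    # -inf = padding sentinel

# Reverse-KL-style term exp(s) * (s - t) hits +inf at the sentinel position.
term = student.exp() * (student - teacher)
loss = torch.nan_to_num(term, posinf=torch.finfo(term.dtype).max).sum()
assert loss.isfinite()   # forward is clamped

loss.backward()
print(student.grad)      # e.g. tensor([1., nan]): the pre-clamp +inf
                         # re-enters the chain rule as 0 * inf = NaN
```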
Fix
---
Mirror the forward-KL server path's masking: after the ``isfinite`` checks
that guard required positions, replace the -inf sentinel with a finite
zero at all known padding positions (``labels == -100``) via
``torch.where``. The label mask in ``_reduce_divergence_loss`` still
excludes those positions from the final loss; the new neutralisation
prevents their -inf values from propagating through ``_add_tail_bucket``
and ``_jsd_divergence`` into the autograd graph.
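A sketch of the pattern (illustrative variable names, not the exact trainer
code):

```python
import torch

teacher_logprobs = torch.tensor([[-1.2, -0.7, float("-inf")]])
labels = torch.tensor([[101, 102, -100]])     # -100 = ignored position

padding = labels == -100
teacher_logprobs = torch.where(padding, 0.0, teacher_logprobs)
# Padding positions now hold a finite dummy value; the label mask downstream
# still drops them from the loss, so they contribute zero gradient.
```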
Tests
-----
``tests/experimental/test_distillation_trainer.py`` is new (DistillationTrainer
had zero dedicated tests at v1.1.0):
- Sentinel contract at the server-path getter.
- The reverse-KL mask pattern produces finite forward AND backward on a
ragged batch (sketched after this section).
- End-to-end training step under ``per_device_train_batch_size=1``,
``gradient_accumulation_steps=2``, variable completion lengths, with a
mocked ``VLLMClient``. Covers ``beta=1.0`` (reverse KL) and ``beta=0.5``
(JSD).
Reproduction pre-fix: ``grad_norm=nan`` on step 1. Post-fix:
``grad_norm`` is finite and padding positions receive zero gradient
(correctly excluded from the learning signal).
A parallel audit of GKDTrainer confirmed it is not vulnerable to the same
class of bug: its teacher runs in-process on a dense rectangular batch,
with no HTTP ragged-to-rectangular reassembly and no -inf sentinel in the
GKD loss path.
Refs: huggingface#5407.
Pull request overview
Fixes a NaN-gradient issue in the experimental distillation trainer’s server-backed reverse-KL / generalized JSD loss when batches contain variable completion lengths, by neutralizing -inf padding sentinels before divergence math runs.
Changes:
- Add masking in `_compute_server_sparse_top_1_divergence_loss` to replace teacher `-inf` sentinels at `labels == -100` positions with finite zeros.
- Clarify the `-inf` sentinel contract and where it is neutralized downstream.
- Add a new regression test suite covering sentinel padding, finite forward/backward behavior, and an end-to-end `train()` run with ragged completion lengths using a mocked `VLLMClient`.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `trl/experimental/distillation/distillation_trainer.py` | Neutralizes `-inf` sentinels at ignored label positions for the server reverse-KL/JSD path to prevent NaN gradients. |
| `tests/experimental/test_distillation_trainer.py` | Adds unit + functional regression tests validating the sentinel contract and guarding against non-finite backward passes under variable completion lengths. |
Collapse the module summary, triple-line test docstrings, and the one-shot helper factories in `tests/experimental/test_distillation_trainer.py` into the repo's terse style. Functional coverage (sentinel pin, mid-level mask finite forward/backward, end-to-end train() under bs*ga>1 with ragged batches for beta=1.0 and beta=0.5) is unchanged; all 4 tests still pass.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Experiments showed the end-to-end regression tests were miscalibrated:
- `bs=1, ga=2` and `bs=2, ga=1` both reproduce `grad_norm=nan` when the fix is removed (because `_get_teacher_token_logprobs_from_server` emits -inf padding not only for cross-sample ragged batches but also via per-sample `completion_offsets`). Parametrize the reverse-KL test over both configs for fuller coverage.
- `beta=0.5` (JSD mixture) does not actually produce NaN without the fix in either config: `_jsd_divergence`'s mixture branch routes student gradients through `log((1-beta)*student_probs + beta*teacher_probs)`, which stays finite when `teacher_probs=0` at padding (see the sketch below). Drop the JSD end-to-end test -- it was a vacuous guard.

Unit + mid-level tests (sentinel contract, mask-keeps-forward-and-backward-finite) are unchanged.
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
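The `beta=0.5` claim above can be checked numerically; a toy version of the
mixture branch (made-up probabilities, `teacher_probs=0` modelling a
neutralised padding position; not the trainer's exact JSD formula):

```python
import torch

beta = 0.5
student_probs = torch.tensor([0.7, 0.3], requires_grad=True)
teacher_probs = torch.tensor([0.0, 0.0])   # zero mass at the padded position

mixture = (1 - beta) * student_probs + beta * teacher_probs
loss = -(student_probs * mixture.log()).sum()
loss.backward()
assert torch.isfinite(student_probs.grad).all()   # mixture > 0 wherever
                                                  # student_probs > 0
```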
- Trim padding-mask comment to two lines focused on what it prevents; the backward-autograd exposition lived in the PR description.
- Drop the explicit `zero` scalar tensor -- `torch.where` broadcasts the `0.0` literal to the tensor's dtype/device (verified bit-exact equivalent in fp32/bf16/fp16; sketched below).
- Mark the end-to-end `trainer.train()` test `@pytest.mark.slow` to match repo convention for heavy tests (saves ~8s per warm CI run).
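The scalar-broadcast equivalence from the second bullet, in isolation (toy
tensors):

```python
import torch

t = torch.tensor([1.0, float("-inf")], dtype=torch.bfloat16)
mask = torch.tensor([False, True])

explicit = torch.where(mask, torch.zeros((), dtype=t.dtype), t)
literal = torch.where(mask, 0.0, t)
assert torch.equal(explicit, literal)   # bit-exact, as noted above
```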
What does this PR do?
Fixes a NaN-gradient bug in `DistillationTrainer`'s server-backed reverse-KL / generalized JSD loss when batches contain per-sample completion lengths that differ.

**Trigger:** `use_teacher_server=True` + `beta > 0` + `per_device_train_batch_size * gradient_accumulation_steps > 1` with variable completion lengths. Forward loss is finite (clamped by `nan_to_num`); `grad_norm=nan` on the first optim step.

**Root cause:** `_get_teacher_token_logprobs_from_server` pads rectangular teacher logprobs with `-inf`. The forward-KL server path (`_compute_server_forward_kl_loss`) masks the sentinel before the divergence math via `valid = teacher > -inf` + `torch.where` + a support mask threaded through `_add_tail_bucket`. The reverse-KL path skips this masking. Unmasked `-inf` flows through `_add_tail_bucket` (producing `[-inf, 0]`) and `_jsd_divergence` (producing `+inf` in forward, clamped by `nan_to_num`, but NaN in backward -- autograd's chain rule does not respect `nan_to_num`). Both paths landed together in huggingface#5407; the asymmetric masking looks like an oversight.

**Fix:** In `_compute_server_sparse_top_1_divergence_loss`, after the existing `isfinite` validation, neutralise the sentinel at known padding positions (`labels == -100`) with a finite zero via `torch.where`, before the shared divergence helper runs. The label mask in `_reduce_divergence_loss` continues to exclude these positions from the final loss.

**Tests:** New `tests/experimental/test_distillation_trainer.py` (trainer had no dedicated tests): `_add_tail_bucket` + `_jsd_divergence` (`beta=1`) post-mask, finite forward & backward; `DistillationTrainer.train()` at `bs=1, ga=2` with variable-length dataset and mocked `VLLMClient` for `beta=1.0` and `beta=0.5`. `pytest tests/experimental/test_distillation_trainer.py -v`: 4/4 pass in 28.12s.

**Env** (`trl env`):

Before submitting
AI writing disclosure